"""
Parallel trainers that share single decord dataloader.
"""
import os
import sys
import time
import pickle
import shutil
import argparse
import numpy as np
from queue import Empty
import multiprocessing as mp
from paddle import fluid
from tensorboardX import SummaryWriter

sys.path.append(
    os.path.realpath(os.path.join(os.path.dirname(__file__), '..')))
from config import XiaoduHiConfig
from interaction.attention_ctrl import AttentionController
from interaction.common.data_v2 import CV2Reader
from interaction.common.data_via_decord import DecordReader, Detector, \
    PostWorker


# Trainer table: one entry per trainer process started by main().
# Key   -> `inputs_type` handed to the trainer / AttentionController.
# Value -> (sub-directory created under --save, GPU id used by the trainer).
Config = dict(
    visual_token=('de_visual_token', 1),
    instance=('de_instance', 1),
    without_inst_fm=('de_instance_wo_fm', 1),
    without_inst_cls=('de_instance_wo_cls', 1),
    without_inst_pos=('de_instance_wo_pos', 1),
)


def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Options are organised in four groups (data, model, train, decord) plus
    the stand-alone ``--use-cv2`` switch.

    Args:
        argv: Optional list of argument strings. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, which is what the
            script entry point relies on.

    Returns:
        argparse.Namespace holding every option declared below.
    """
    parser = argparse.ArgumentParser(
        description='Parallel trainer for attention controller.')

    data_group = parser.add_argument_group('data')
    data_group.add_argument(
        '--wae-dir', '-wd', type=str, default='data/raw_wae',
        help='Directory of multimodal action embeddings, generated by '
        'scripts/collect_v2_act_emb.py.')
    data_group.add_argument(
        '--train-dataset', '-tr', type=str,
        default='data/xiaoduHi_train_v2.pkl',
        help='Path to training dataset which is generated by '
        'scripts/prepare_data.py.')
    data_group.add_argument(
        '--full-neg-txt', type=str,
        default='data/full_neg_valid_i0.80_s0.10.txt',
        help='Path to text file recording paths of full negative frames.')
    data_group.add_argument(
        '--pos-video-dir', type=str, default='data/xiaodu_clips_v2',
        help='Directory of preprocessed video obj-tracking results. '
        'It is generated by scripts/collect_v2_data.py')

    model_group = parser.add_argument_group('model')
    model_group.add_argument(
        '--yolov4-model-dir', type=str,
        default='tools/yolov4_paddle/inference_model',
        help='Directory of YOLOv4 model.')
    model_group.add_argument(
        '--num-decoder-blocks', '-nb', type=int, default=6,
        help='Number of transformer decoder blocks of the attention model.')
    model_group.add_argument(
        '--num-heads', '-nh', type=int, default=8,
        help='Number of heads for multihead attention.')
    model_group.add_argument(
        '--model-dim', '-md', type=int, default=512,
        help='Dimension of key, query and value in self-attention.')
    model_group.add_argument(
        '--ffn-dim', '-fd', type=int, default=2048,
        help='Dimension of feedforward network after self-attention.')
    model_group.add_argument(
        '--normalize-before', default=False, action='store_true',
        help='Whether to use layer normalization at the beginning of '
        'each transformer decoder block.')
    model_group.add_argument(
        '--frame-emb-trainable', default=False, action='store_true',
        help='Whether frame embedding is trainable or use 1-D sine '
        'positional embedding as original transformer paper.')
    model_group.add_argument(
        '--use-last-act-loss', default=False, action='store_true',
        help='Whether to compute action loss only for last frame.')

    train_group = parser.add_argument_group('train')
    train_group.add_argument(
        '--trigger-loss-coef', type=float, default=5.0,
        help='Coefficient of trigger loss.')
    train_group.add_argument(
        '--obj-loss-coef', type=float, default=1.0,
        help='Coefficient of interaction object loss.')
    train_group.add_argument(
        '--act-loss-coef', type=float, default=1.0,
        help='Coefficient of multimodal action NNL loss.')
    train_group.add_argument(
        '--epochs', type=int, default=10,
        help='The number of epochs to train the core controller.')
    train_group.add_argument(
        '--dropout', type=float, default=0.0,
        help='Probability of dropout, default 0.0, meaning no dropout.')
    train_group.add_argument(
        '--lr', type=float, default=0.0001, help='Learning rate.')
    train_group.add_argument(
        '--bs', type=int, default=8, help='Batch size.')
    train_group.add_argument(
        '--save', type=str, default='save',
        help='Directory to save parameters and log files.')

    decord_group = parser.add_argument_group('decord')
    decord_group.add_argument(
        '--decord-readers', type=int, default=8,
        help='Number of parallel decord readers to use.')
    decord_group.add_argument(
        '--decord-detectors', type=int, default=2,
        help='Number of yolov4 detectors to use.')
    decord_group.add_argument(
        '--decord-post-workers', type=int, default=4,
        help='Number of post workers to use.')
    decord_group.add_argument(
        '--detector-gpus', type=str, default='0',
        help='Config for detector gpus.')
    decord_group.add_argument(
        '--decord-ds-pkl', type=str, default='data/xiaoduHi_decord.pkl',
        help='Path to decord dataset pkl, generated by '
        'scripts/prepare_dataset.py -vt v2_decord')
    decord_group.add_argument(
        '--queue-max-size', type=int, default=100,
        help='Maximum size of queue.')
    decord_group.add_argument(
        '--dataloader-timeout', type=int, default=60,
        help='Timeout to treat the queue as an end of an epoch.')

    parser.add_argument(
        '--use-cv2', default=False, action='store_true',
        help='Whether to use cv2 as a replacement of decord. '
        'This is faster when videos are exported as frames.')

    return parser.parse_args(argv)


def train_epoch(data_queue, inputs_type, update_step, exe, program,
                attention_ctrl, log_steps=1, log_file=None, tb_writer=None):
    """Consume batches from `data_queue` until an empty one arrives.

    Every non-empty batch is reduced to the entries named in
    `attention_ctrl.feed_list`, run through `exe` on `program`, and the four
    fetched losses are printed / logged.

    Args:
        data_queue: Object whose blocking `get()` returns a batch dict; an
            empty dict marks the end of the epoch.
        inputs_type: Label used as prefix of the printed log line.
        update_step: Number of optimisation steps performed so far.
        exe: Executor providing `run(program, feed=..., fetch_list=...)`.
        program: Program handed to `exe.run`.
        attention_ctrl: Provides `feed_list`, `loss`, `trigger_loss`,
            `obj_loss` and `act_loss`.
        log_steps: Print a log line every `log_steps` steps.
        log_file: Optional path; one CSV row of the four losses is appended
            per step.
        tb_writer: Optional writer exposing `add_scalar(tag, value, step)`.

    Returns:
        The updated step counter.
    """
    feed_names = [var.name for var in attention_ctrl.feed_list]
    loss_tags = ('total_loss', 'trigger_loss', 'obj_loss', 'act_loss')

    tick = time.time()
    while True:
        batch = data_queue.get()  # dict of batched inputs
        if not batch:
            # Empty dict is the end-of-epoch sentinel.
            return update_step

        fetched = exe.run(
            program,
            feed={name: batch[name] for name in feed_names},
            fetch_list=[attention_ctrl.loss,
                        attention_ctrl.trigger_loss,
                        attention_ctrl.obj_loss,
                        attention_ctrl.act_loss])
        total_loss, trigger_loss, obj_loss, act_loss = fetched
        losses = (total_loss[0], trigger_loss[0], obj_loss[0], act_loss[0])

        elapsed = time.time() - tick
        tick = time.time()
        update_step += 1

        if update_step % log_steps == 0:
            pieces = [
                '[{}-{}, {:.2f}s] Total loss: {:.4f}'.format(
                    inputs_type, update_step, elapsed, losses[0]),
                'Trigger loss: {:.4f}'.format(losses[1]),
                'Obj loss: {:.4f}'.format(losses[2]),
                'Act loss: {:.4f}'.format(losses[3]),
            ]
            print(', '.join(pieces))

        if tb_writer is not None:
            for tag, value in zip(loss_tags, losses):
                tb_writer.add_scalar(tag, value, update_step)

        if log_file is not None:
            row = '{:.4f},{:.4f},{:.4f},{:.4f}\n'.format(*losses)
            with open(log_file, 'a') as f:
                f.write(row)


def trainer_func(data_queue, inputs_type, gpu, sub_save_dir, args, cfg):
    """Entry point of one trainer process.

    Builds an AttentionController for `inputs_type`, then trains it for
    `args.epochs` epochs on batches read from `data_queue`, saving the
    parameters after every epoch.

    Args:
        data_queue: Queue of batch dicts; an empty dict ends an epoch
            (see train_epoch).
        inputs_type: Passed to AttentionController and used in log lines.
        gpu: CUDA device id used for the executor.
        sub_save_dir: Sub-directory of `args.save` receiving the outputs.
        args: Parsed command-line arguments (see parse_args).
        cfg: Config object providing `ob_window_len`, `tokens_per_frame`
            and `visual_token_dim`.

    Outputs written under `<args.save>/<sub_save_dir>`:
        loss.csv            one row of losses per step
        logdir/             tensorboard events
        epoch_<id>/         parameters and `tb_state.txt` (the step count)
    """
    wae_ndarray = np.load(os.path.join(args.wae_dir, 'raw_wae.npy'))
    update_step = 0

    train_program = fluid.Program()
    startup_program = fluid.Program()

    with fluid.program_guard(train_program, startup_program):
        attention_ctrl = AttentionController(
            inputs_type=inputs_type,
            num_actions=wae_ndarray.shape[0],
            act_tr_dim=wae_ndarray.shape[1],
            act_emb_ndarray=wae_ndarray,
            num_frames=cfg.ob_window_len,
            tokens_per_frame=cfg.tokens_per_frame,
            visual_token_dim=cfg.visual_token_dim,
            model_dim=args.model_dim,
            num_decoder_blocks=args.num_decoder_blocks,
            num_heads=args.num_heads,
            ffn_dim=args.ffn_dim,
            dropout=args.dropout,
            normalize_before=args.normalize_before,
            frame_emb_trainable=args.frame_emb_trainable,
            trigger_loss_coef=args.trigger_loss_coef,
            obj_loss_coef=args.obj_loss_coef,
            act_loss_coef=args.act_loss_coef,
            use_last_act_loss=args.use_last_act_loss,
            mode='train')
        # NOTE(review): `preds` and `test_program` are not used further in
        # this function; both calls are kept because predict() presumably
        # adds ops to the program being built -- confirm before removing.
        preds = attention_ctrl.predict()
        test_program = train_program.clone(for_test=True)

        optimizer = fluid.optimizer.AdamOptimizer(
            learning_rate=args.lr,
            regularization=fluid.regularizer.L2Decay(
                regularization_coeff=0.1))
        optimizer.minimize(attention_ctrl.loss)

    place = fluid.CUDAPlace(gpu)
    exe = fluid.Executor(place)
    exe.run(startup_program)

    save_dir = os.path.join(args.save, sub_save_dir)
    os.makedirs(save_dir, exist_ok=True)

    train_log = os.path.join(save_dir, 'loss.csv')
    tb_writer = SummaryWriter(logdir=os.path.join(save_dir, 'logdir'))

    try:
        for epoch_id in range(args.epochs):
            update_step = train_epoch(
                data_queue, inputs_type, update_step, exe, train_program,
                attention_ctrl, log_file=train_log, tb_writer=tb_writer)

            # Replace any stale checkpoint of the same epoch.
            ep_dir = os.path.join(save_dir, 'epoch_{}'.format(epoch_id))
            shutil.rmtree(ep_dir, ignore_errors=True)
            os.mkdir(ep_dir)
            fluid.io.save_params(exe, ep_dir, main_program=train_program)

            # Record the global step reached at the end of this epoch.
            tb_state = os.path.join(ep_dir, 'tb_state.txt')
            with open(tb_state, 'w') as f:
                f.write('{}'.format(update_step))
    finally:
        # Flush buffered tensorboard events even if training aborts.
        tb_writer.close()


def sample_neg_train_ids(data):
    """Choose which negative annotations are used for training.

    Negatives of type 'r2' are shuffled and split between test and train
    with the same test ratio as the positives; the train part is kept.
    All negatives of type 'r1' are used for training. Other video types
    are ignored.

    The global numpy RNG is seeded with 0 here, so the selection (and any
    later use of `np.random` in this process) is reproducible.

    Args:
        data: Dict with 'neg' (list of annotations carrying 'VideoType'),
            'pos_test' and 'pos_train' (only their lengths are used).

    Returns:
        (neg_train_ids, neg_train_offset): a shuffled list of indices into
        `data['neg']` and the initial offset into that list (always 0).
    """
    neg_r1, neg_r2 = [], []
    for idx, anno in enumerate(data['neg']):
        video_type = anno['VideoType']
        if video_type == 'r2':
            neg_r2.append(idx)
        elif video_type == 'r1':
            neg_r1.append(idx)

    np.random.seed(0)
    order = np.arange(len(neg_r2))
    np.random.shuffle(order)

    num_test = len(data['pos_test'])
    num_train = len(data['pos_train'])
    num_pos = num_test + num_train
    # Number of leading shuffled r2 entries reserved for testing.
    n = int(order.shape[0] * num_test / float(num_pos)) if num_pos else 0

    # `order` holds positions inside `neg_r2`; translate them back into
    # indices of `data['neg']` so they live in the same index space as
    # the entries of `neg_r1`.
    neg_train_ids = [neg_r2[i] for i in order[n:]]
    neg_train_ids.extend(neg_r1)

    # Mix r1 and r2 entries.
    order = np.arange(len(neg_train_ids))
    np.random.shuffle(order)
    neg_train_ids = [neg_train_ids[i] for i in order]
    neg_train_offset = 0

    return neg_train_ids, neg_train_offset


def get_annos_per_reader(data, reader_id, readers, neg_train_ids,
                         neg_train_offset):
    """Assemble the shuffled annotation list of one reader for one epoch.

    Positive index ``i`` belongs to the reader for which
    ``i % readers == reader_id``. As many negatives as positives are
    drawn, walking cyclically through `neg_train_ids` starting at
    `neg_train_offset`, and are assigned to readers by the same rule.

    Args:
        data: Dict with the lists 'pos_train' and 'neg'.
        reader_id: Index of the reader the list is built for.
        readers: Total number of readers.
        neg_train_ids: Indices into `data['neg']` usable for training.
        neg_train_offset: Start position inside `neg_train_ids`.

    Returns:
        List of positive and negative annotations in random order (uses
        the global numpy RNG).
    """
    total_ids = len(neg_train_ids)

    # Indices handled by this reader; identical for both sample kinds
    # because the same number of positives and negatives is drawn.
    owned = [i for i in range(len(data['pos_train']))
             if i % readers == reader_id]

    annos = [data['pos_train'][i] for i in owned]
    for i in owned:
        neg_idx = neg_train_ids[(i + neg_train_offset) % total_ids]
        annos.append(data['neg'][neg_idx])

    order = np.arange(len(annos))
    np.random.shuffle(order)
    return [annos[i] for i in order]


def convert_to_feed(batch):
    """Collate a non-empty list of sample dicts into one feed dict.

    For every key of the first sample the per-sample arrays are stacked
    along a new leading axis and cast to the dtype the first sample has
    for that key.

    Args:
        batch: List of dicts that share the keys of `batch[0]`; values
            must expose `.dtype` (e.g. numpy arrays).

    Returns:
        Dict mapping each key to the stacked array.
    """
    head = batch[0]
    return {
        key: np.stack([sample[key] for sample in batch]).astype(
            head[key].dtype)
        for key in head.keys()
    }


def load_dataset(args):
    """Load the annotations and choose the negative training samples.

    Two sources are supported:

    * decord (default): `args.decord_ds_pkl` already holds the complete
      dataset dict; the negative ids come from sample_neg_train_ids().
    * cv2 (`args.use_cv2`): positives are read from `args.train_dataset`
      and get a 'Video' path inside `args.pos_video_dir`; negatives are
      the existing directories listed one per line in
      `args.full_neg_txt`, of which a seeded random 80% are used for
      training.

    Args:
        args: Parsed command-line arguments (see parse_args).

    Returns:
        (data, neg_train_ids, neg_train_offset) where `data` has at least
        the keys 'pos_train' and 'neg', `neg_train_ids` indexes into
        `data['neg']` and `neg_train_offset` is 0.
    """
    # NOTE(security): pickle.load can execute arbitrary code; only point
    # these options at dataset files from a trusted source.
    if not args.use_cv2:
        with open(args.decord_ds_pkl, 'rb') as f:
            data = pickle.load(f)
        neg_train_ids, neg_train_offset = sample_neg_train_ids(data)
        return data, neg_train_ids, neg_train_offset

    with open(args.train_dataset, 'rb') as f:
        pos_anno_lst = pickle.load(f)

    data = {'pos_train': [], 'neg': []}
    for anno in pos_anno_lst:
        anno['Video'] = os.path.join(
            args.pos_video_dir, '{}.mp4'.format(anno['VideoID']))
        data['pos_train'].append(anno)

    with open(args.full_neg_txt, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # realpath('') resolves to the current working directory,
                # so a blank line would otherwise add it as a sample.
                continue
            path = os.path.realpath(line)
            if os.path.isdir(path):
                data['neg'].append(
                    {'WAE_id': 0, 'VideoType': 'neg_frames', 'Path': path})

    # Seeded shuffle; the first 80% of the negatives are used for training.
    np.random.seed(0)
    ids = np.arange(len(data['neg']))
    np.random.shuffle(ids)
    n = int(ids.shape[0] * 0.8)
    neg_train_ids = list(ids[:n])
    neg_train_offset = 0

    return data, neg_train_ids, neg_train_offset


def main(args):
    """Run the shared data pipeline and feed every trainer process.

    One trainer process is started per entry of `Config`, each with its
    own queue. A single reader -> detector -> post-worker pipeline
    produces samples; this function groups them into batches of
    `args.bs` and puts every batch on all trainer queues. After each
    epoch an empty dict is sent so that the trainers leave the epoch and
    save their model.

    An epoch is considered finished when no sample arrives within
    `args.dataloader_timeout` seconds.

    Args:
        args: Parsed command-line arguments (see parse_args).
    """
    cfg = XiaoduHiConfig()
    cfg.scene_sensor_algo = 'yolov4'

    data, neg_train_ids, neg_train_offset = load_dataset(args)

    # Create process queue
    proc_manager = mp.Manager()
    read_frame_queue = proc_manager.Queue(args.queue_max_size)
    process_inst_queue = proc_manager.Queue(args.queue_max_size)
    dataloader_queue = proc_manager.Queue(args.queue_max_size)

    # Create trainers
    trainer_Qs, trainer_procs = [], []
    for inputs_type, (sub_save_dir, gpu) in Config.items():
        Q = proc_manager.Queue(args.queue_max_size)
        trainer_Qs.append(Q)
        trainer_proc = mp.Process(
            target=trainer_func,
            args=(Q, inputs_type, gpu, sub_save_dir, args, cfg))
        trainer_proc.start()
        trainer_procs.append(trainer_proc)

    # Create data pipeline
    detector_gpus = [int(i) for i in args.detector_gpus.split(',')]
    reader_lst, detector_lst, pw_lst = [], [], []
    for idx in range(args.decord_readers):
        if not args.use_cv2:
            reader_lst.append(
                DecordReader(idx, proc_manager, read_frame_queue, []))
        else:
            reader_lst.append(
                CV2Reader(args.pos_video_dir, idx, proc_manager,
                          read_frame_queue, []))
    for idx in range(args.decord_detectors):
        # Detectors are spread round-robin over the configured GPUs.
        gpu = detector_gpus[idx % len(detector_gpus)]
        detector_lst.append(
            Detector(gpu, args.yolov4_model_dir, proc_manager,
                     read_frame_queue, process_inst_queue))
    for idx in range(args.decord_post_workers):
        pw_lst.append(PostWorker(
            proc_manager, process_inst_queue, dataloader_queue))

    # Start data workers
    for detector in detector_lst:
        detector.start()
        time.sleep(5)  # wait for model loading
    for post_worker in pw_lst:
        post_worker.start()
    for reader in reader_lst:
        reader.start()

    for epoch_id in range(args.epochs):
        # update anno lst
        for idx, reader in enumerate(reader_lst):
            reader.update(
                get_annos_per_reader(data, idx, args.decord_readers,
                                     neg_train_ids, neg_train_offset))
        # Next epoch continues with the following window of negatives.
        neg_train_offset += len(data['pos_train'])

        for reader in reader_lst:
            reader.next_epoch()

        batch = []
        while True:
            try:
                # Must not be named `data`: that name holds the dataset
                # dict, which is needed again in the next epoch.
                sample = dataloader_queue.get(
                    timeout=args.dataloader_timeout)
            except Empty:
                break

            batch.append(sample)
            if len(batch) == args.bs:
                feed_dict = convert_to_feed(batch)
                for Q in trainer_Qs:
                    Q.put(feed_dict)

                batch = []

        # Flush the last, possibly smaller, batch of the epoch.
        if len(batch) > 0:
            feed_dict = convert_to_feed(batch)
            for Q in trainer_Qs:
                Q.put(feed_dict)

        for Q in trainer_Qs:
            # Make sure that trainer exits this epoch and saves model
            Q.put({})

        print('==============================')
        print('Finished epoch {}'.format(epoch_id))

    # Stop data workers
    for reader in reader_lst:
        reader.stop()
    for detector in detector_lst:
        detector.stop()
    for post_worker in pw_lst:
        post_worker.stop()

    # Wait until the trainers have consumed their queues and written the
    # final checkpoint before this process (owner of the manager that
    # backs those queues) returns.
    for trainer_proc in trainer_procs:
        trainer_proc.join()


if __name__ == '__main__':
    # Called without any option: show the usage text instead of starting
    # a run with all defaults.
    launched_bare = len(sys.argv) == 1
    if launched_bare:
        sys.argv.append('-h')
    args = parse_args()

    # Worker processes are created with the 'spawn' start method.
    mp.set_start_method('spawn')
    main(args)
